// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_SOFTMAX_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_SOFTMAX_H_

#include <ac_math.h>
#include <mc_connections.h>
#include <systemc.h>

#include "src/AccelTypes.h"
#include "src/VectorUtils.h"

// Streaming softmax unit.
//
// For every VectorParams descriptor received on `paramsIn`, the unit handles
// Y rows. Each row is consumed from `vectorIn` in three consecutive passes of
// X / WIDTH packed vectors each:
//   1. reduce the row to its maximum,
//   2. accumulate sum(exp(x - max)),
//   3. emit exp(x - max) / sum on `vectorOut`.
// Every pass pops fresh data from `vectorIn`; presumably the upstream producer
// replays the same row once per pass -- TODO confirm against the producer.
//
// Template parameters:
//   DTYPE - element type stored in the packed vectors.
//   WIDTH - number of elements carried by one Pack1D vector.
template <typename DTYPE, int WIDTH>
SC_MODULE(SoftmaxUnit) {
  // Clock and reset inputs; the reset is treated as asserted when low (see
  // the constructor).
  sc_in<bool> CCS_INIT_S1(clk);
  sc_in<bool> CCS_INIT_S1(rstn);

  // One descriptor per softmax operation.
  Connections::In<VectorParams> CCS_INIT_S1(paramsIn);
  // Packed input elements.
  Connections::In<Pack1D<DTYPE, WIDTH> > CCS_INIT_S1(vectorIn);
  // Packed, normalized results.
  Connections::Out<Pack1D<DTYPE, WIDTH> > CCS_INIT_S1(vectorOut);

  SC_CTOR(SoftmaxUnit) {
    // Single clocked thread, sensitive to the rising clock edge, with an
    // asynchronous reset that is active when rstn is false.
    SC_THREAD(run);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  // Main process: resets all channels, then serves descriptors forever.
  void run() {
    paramsIn.Reset();
    vectorIn.Reset();
    vectorOut.Reset();

    while (true) {
      VectorParams cfg = paramsIn.Pop();

      // Row count and row length are looked up in loop nest 1 of the
      // descriptor through its y/x loop indices.
      int numRows = cfg.loops[1][cfg.yLoopIndex[1]];
      int rowLength = cfg.loops[1][cfg.xLoopIndex[1]];

      // Number of packed vectors per row. Integer division: a row length that
      // is not a multiple of WIDTH leaves the remainder elements unconsumed.
      const int chunksPerRow = rowLength / WIDTH;

      for (int row = 0; row < numRows; ++row) {
        // Pass 1: running maximum over the row.
        // NOTE(review): the maximum starts at 0, so a row whose elements are
        // all negative would presumably report 0 rather than its true maximum
        // -- confirm this is acceptable for the DTYPE range in use.
        DTYPE rowMax = 0;
#pragma hls_pipeline_init_interval 1
        for (int i = 0; i < chunksPerRow; ++i) {
          Pack1D<DTYPE, WIDTH> chunk = vectorIn.Pop();
          reduceMax<DTYPE, WIDTH>(chunk, rowMax);
        }

        // Pass 2: accumulate sum(exp(x - max)) over the row.
        DTYPE expSum = 0;
        for (int i = 0; i < chunksPerRow; ++i) {
          Pack1D<DTYPE, WIDTH> chunk = vectorIn.Pop();
          vectorSub<DTYPE, WIDTH>(chunk, rowMax);
          vectorExp<DTYPE, WIDTH>(chunk);
          reduceSum<DTYPE, WIDTH>(chunk, expSum);
        }

        // Pass 3: emit exp(x - max) / sum, one packed vector at a time.
        for (int i = 0; i < chunksPerRow; ++i) {
          Pack1D<DTYPE, WIDTH> chunk = vectorIn.Pop();
          vectorSub<DTYPE, WIDTH>(chunk, rowMax);
          vectorExp<DTYPE, WIDTH>(chunk);
          vectorDiv<DTYPE, WIDTH>(chunk, expSum);
          vectorOut.Push(chunk);
        }
      }
    }
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_SOFTMAX_H_
